{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7fc6e1cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(os.path.abspath('../../code'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcee14a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import regex\n",
    "from collections import Counter\n",
    "\n",
    "import numpy as np\n",
    "np.set_printoptions(formatter={'float': lambda x: \"{0:0.3f}\".format(x)})\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_columns', 15)\n",
    "pd.set_option('display.width', 320)\n",
    "\n",
    "from utils.io import read_label_config\n",
    "from utils.corpus import JsonlinesAnnotationsCorpus\n",
    "from utils.bsc_model import BSCModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "441cec08",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "# notebook settings, collected in a single namespace object\n",
    "args = SimpleNamespace(\n",
    "    # label scheme and parsed annotations\n",
    "    label_config_file='../../data/annotation/doccano_label_config.json',\n",
    "    input_file='../../data/annotation/parsed/uk-manifestos_annotations.jsonl',\n",
    "    # folder, folder name pattern, and file format used to locate the metadata files\n",
    "    data_path='../../data/annotation/annotations/',\n",
    "    data_folder_pattern='uk-manifestos',\n",
    "    data_file_format='csv',\n",
    "    # output location\n",
    "    output_file='../../data/annotation/labeled/uk-manifestos_all_labeled.jsonl',\n",
    "    overwrite_output=False,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9940ee7",
   "metadata": {},
   "source": [
    "## Read the annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0485a7f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "job\n",
      "group-mentions-annotation-uk-manifestos-round-02         3377\n",
      "group-mentions-annotation-uk-manifestos-round-03         1557\n",
      "group-mentions-annotation-uk-manifestos-round-01         1400\n",
      "group-mentions-annotation-uk-manifestos-other-parties    1362\n",
      "group-mentions-annotation-uk-manifestos-2017+19           900\n",
      "Name: count, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "fps = [str(fp) for fp in Path(args.data_path).glob(f'*{args.data_folder_pattern}*/*.{args.data_file_format}')]\n",
    "# fail early with a readable message: pd.concat() cannot combine an empty collection\n",
    "if not fps:\n",
    "    raise FileNotFoundError(f'no *.{args.data_file_format} files found in {args.data_path}')\n",
    "\n",
    "# read metadata: one data frame per file, keyed by the file name without the '_ids.csv' suffix (the job name)\n",
    "# note: Path(fp).name extracts the file name on every OS, whereas splitting on '/' fails for Windows paths\n",
    "metadata = pd.concat({Path(fp).name.replace('_ids.csv', ''): pd.read_csv(fp) for fp in fps}, axis=0)\n",
    "metadata['job'] = metadata.index.get_level_values(0)\n",
    "metadata = metadata.reset_index(drop=True)\n",
    "print(metadata.job.value_counts())\n",
    "jobs = sorted(metadata.job.unique().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c9b210fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8576/8576 [01:08<00:00, 124.38it/s]\n"
     ]
    }
   ],
   "source": [
    "# read the label config and use it to set up the annotations corpus\n",
    "cat2code = read_label_config(args.label_config_file)\n",
    "\n",
    "acorp = JsonlinesAnnotationsCorpus(cat2code)\n",
    "# load the parsed annotations from the JSON lines file\n",
    "acorp.load_from_jsonlines(args.input_file, verbose=args.verbose)\n",
    "# NOTE: loading is slow (about one minute for the 8,576 documents in the recorded run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "56fd658e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# docs =  8576\n",
      "# gold items = 610\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'O': 0,\n",
       " 'I-social group': 1,\n",
       " 'B-social group': 7,\n",
       " 'I-political group': 2,\n",
       " 'B-political group': 8,\n",
       " 'I-political institution': 3,\n",
       " 'B-political institution': 9,\n",
       " 'I-organization, public institution, or collective actor': 4,\n",
       " 'B-organization, public institution, or collective actor': 10,\n",
       " 'I-implicit social group reference': 5,\n",
       " 'B-implicit social group reference': 11,\n",
       " 'I-unsure': 6,\n",
       " 'B-unsure': 12}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# corpus size and number of documents with at least one label set (reported as gold items)\n",
    "n_docs = acorp.ndocs\n",
    "n_gold_labeled = sum(1 for doc in acorp.docs if doc.n_labels > 0)\n",
    "print('# docs = ', n_docs)\n",
    "print('# gold items =', n_gold_labeled)\n",
    "# show the mapping of label names to integer codes\n",
    "acorp.label_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "60a595f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: print the ids of documents whose GOLD labels contain the 'B-unsure' code (none expected)\n",
    "for doc in acorp.docs:\n",
    "    if not doc.n_labels > 0:\n",
    "        continue\n",
    "    if cat2code['B-unsure'] in doc.labels['GOLD']:\n",
    "        print(doc.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "97d23e73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({0: 9974,\n",
       "         1: 894,\n",
       "         4: 468,\n",
       "         3: 346,\n",
       "         7: 326,\n",
       "         5: 311,\n",
       "         10: 242,\n",
       "         9: 185,\n",
       "         11: 150,\n",
       "         2: 147,\n",
       "         8: 110})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# count how often each label code occurs in the GOLD labels\n",
    "gold_types = Counter()\n",
    "for doc in acorp.docs:\n",
    "    if 'GOLD' not in doc.labels:\n",
    "        continue\n",
    "    gold_types.update(doc.labels['GOLD'].tolist())\n",
    "gold_types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "278841e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop 'unsure' labels\n",
    "def recode_labels(codes):\n",
    "    \"\"\"Return a copy of `codes` with the 'unsure' codes removed from the label scheme.\n",
    "\n",
    "    Codes 6 ('I-unsure') and 12 ('B-unsure') become 0 ('O'); the remaining codes\n",
    "    above 6 (the 'B-*' labels) are shifted down by one so that the codes stay contiguous.\n",
    "    The function is not idempotent: apply it exactly once per label array.\n",
    "    \"\"\"\n",
    "    is_unsure = (codes == 6) | (codes == 12)\n",
    "    recoded = codes.copy()\n",
    "    # set \"unsure\" to outside\n",
    "    recoded[is_unsure] = 0\n",
    "    # close the gap left by the removed code 6\n",
    "    recoded[recoded > 6] -= 1\n",
    "    return recoded\n",
    "\n",
    "# guard against re-running this cell: recode_labels() is not idempotent, so label arrays\n",
    "# that were already recoded must not be passed through it a second time\n",
    "if 'B-unsure' not in acorp.label_map or 'I-unsure' not in acorp.label_map:\n",
    "    raise RuntimeError(\"'unsure' labels were already dropped; reload the corpus before re-running this cell\")\n",
    "\n",
    "# recode every label array exactly once\n",
    "# note: the entries of `labels` (e.g. 'GOLD') are recoded in their own loop. Nesting that loop inside\n",
    "#       the loop over annotators applies recode_labels() once per annotator, which for a doc with two\n",
    "#       annotators maps 'B-social group' to 'O' and shifts the other 'B-*' codes down twice.\n",
    "for i in range(acorp.ndocs):\n",
    "    doc = acorp.docs[i]\n",
    "    for j in range(doc.n_annotations):\n",
    "        annotator = doc.annotators[j]\n",
    "        doc.annotations[annotator] = recode_labels(doc.annotations[annotator])\n",
    "    for k in list(doc.labels.keys()):\n",
    "        doc.labels[k] = recode_labels(doc.labels[k])\n",
    "\n",
    "# remove the 'unsure' entries from both label maps\n",
    "idx = acorp.label_map.pop('B-unsure')\n",
    "acorp.label_map_inv.pop(idx)\n",
    "idx = acorp.label_map.pop('I-unsure')\n",
    "acorp.label_map_inv.pop(idx)\n",
    "\n",
    "# shift the codes 7...11 (the remaining 'B-*' labels) down by one, mirroring recode_labels()\n",
    "for c in range(7, 12):\n",
    "    acorp.label_map[acorp.label_map_inv[c]] -= 1\n",
    "\n",
    "# rebuild the inverse map from the updated label map and verify that the codes are unique\n",
    "acorp.label_map_inv = {c: l for l, c in acorp.label_map.items()}\n",
    "assert len(acorp.label_map) == len(acorp.label_map_inv), 'label codes are not unique after recoding'\n",
    "\n",
    "acorp.inside_labels = list(range(1, 6))\n",
    "acorp.beginning_labels = list(range(6, 11))\n",
    "\n",
    "# refresh the corpus attributes that are set from the docs\n",
    "acorp.ndocs = len(acorp.docs)\n",
    "acorp.doc_ids = [doc.id for doc in acorp.docs]\n",
    "acorp.doc_id2idx = {doc.id: i for i, doc in enumerate(acorp.docs)}\n",
    "acorp.doc_idx2id = {i: doc.id for i, doc in enumerate(acorp.docs)}\n",
    "\n",
    "acorp.annotator_label_counts = acorp._count_annotator_labels()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8b9c3ace",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>emarie</th>\n",
       "      <th>sjasmin</th>\n",
       "      <th>gold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>104345</td>\n",
       "      <td>107456</td>\n",
       "      <td>10300.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4073</td>\n",
       "      <td>3188</td>\n",
       "      <td>894.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2126</td>\n",
       "      <td>1977</td>\n",
       "      <td>110.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>777</td>\n",
       "      <td>643</td>\n",
       "      <td>147.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>942</td>\n",
       "      <td>896</td>\n",
       "      <td>185.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1886</td>\n",
       "      <td>1724</td>\n",
       "      <td>346.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1281</td>\n",
       "      <td>1439</td>\n",
       "      <td>242.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2122</td>\n",
       "      <td>1766</td>\n",
       "      <td>468.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1333</td>\n",
       "      <td>1370</td>\n",
       "      <td>150.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>913</td>\n",
       "      <td>833</td>\n",
       "      <td>311.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>774</td>\n",
       "      <td>813</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    emarie  sjasmin     gold\n",
       "0   104345   107456  10300.0\n",
       "1     4073     3188    894.0\n",
       "6     2126     1977    110.0\n",
       "2      777      643    147.0\n",
       "7      942      896    185.0\n",
       "3     1886     1724    346.0\n",
       "8     1281     1439    242.0\n",
       "4     2122     1766    468.0\n",
       "9     1333     1370    150.0\n",
       "5      913      833    311.0\n",
       "10     774      813      NaN"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# tabulate the label counts per annotator next to the label counts in the GOLD labels\n",
    "annotations = acorp.annotator_label_counts\n",
    "gold = Counter()\n",
    "for doc in acorp.docs:\n",
    "    if 'GOLD' in doc.labels:\n",
    "        gold.update(doc.labels['GOLD'].tolist())\n",
    "# order the rows like the entries of the label map\n",
    "(\n",
    "    pd.DataFrame(annotations)\n",
    "    .join(pd.DataFrame({'gold': dict(gold)}))\n",
    "    .loc[acorp.label_map.values()]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b31fb2de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9692015895953757\n"
     ]
    }
   ],
   "source": [
    "# custom token cleaning function\n",
    "def clean_tokens(x):\n",
    "    \"\"\"Replace currency- and number-like tokens by placeholder symbols.\n",
    "\n",
    "    Returns '<MONEY>' if `x` starts with a currency symbol, '<DIGITS>' if `x` consists of\n",
    "    digits (optionally followed by digit groups separated by '.', ',', '/' or '-' and by\n",
    "    trailing word characters), and `x` itself otherwise.\n",
    "    \"\"\"\n",
    "    # checked in order; the first matching pattern decides the placeholder\n",
    "    placeholders = (\n",
    "        (r\"^\\p{Sc}\", \"<MONEY>\"),\n",
    "        (r\"^\\d+([.,/-]\\d+)*\\w*$\", \"<DIGITS>\"),\n",
    "    )\n",
    "    for pattern, placeholder in placeholders:\n",
    "        if regex.match(pattern, x):\n",
    "            return placeholder\n",
    "\n",
    "    return x\n",
    "\n",
    "# apply the cleaning function to each doc's tokens (in place) and tally the results in a Counter\n",
    "cleaned_vocab = Counter()\n",
    "for i in range(acorp.ndocs):\n",
    "    doc_tokens = acorp.docs[i].clean_tokens(fun=clean_tokens)\n",
    "    cleaned_vocab.update(doc_tokens)\n",
    "\n",
    "# size of the cleaned vocabulary relative to the original vocabulary\n",
    "print(len(cleaned_vocab) / len(acorp.vocab))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c0fa548",
   "metadata": {},
   "source": [
    "## Annotation aggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b751d6c",
   "metadata": {},
   "source": [
    "### Prepare the Bayesian sequence combination (BSC) sequence annotation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "35f71c33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parallel can run 10 jobs simultaneously, with 10 cores\n"
     ]
    }
   ],
   "source": [
    "# set up the BSC model for the corpus, with the 'GOLD' labels as gold labels\n",
    "amodel = BSCModel(\n",
    "    acorp,\n",
    "    max_iter=30,\n",
    "    gold_labels='GOLD',\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4616fe2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# label classes: 11\n",
      "# annotators:    2\n",
      "# docs: 8576\n",
      "# tokens: 179641\n",
      "# docs with gold labels: 610\n",
      "# tokens with gold labels: 2853\n"
     ]
    }
   ],
   "source": [
    "# summarize the data held by the model\n",
    "# NOTE(review): the last count uses `> 0`, i.e. it skips gold code 0 ('O' in acorp.label_map),\n",
    "#               while the document count uses `> -1`; confirm which threshold is intended\n",
    "for description, value in (\n",
    "    ('# label classes:', amodel.num_classes),\n",
    "    ('# annotators:   ', amodel.num_annotators),\n",
    "    ('# docs:', amodel.num_docs),\n",
    "    ('# tokens:', amodel.num_tokens),\n",
    "    ('# docs with gold labels:', (amodel.gold[amodel.doc_start == 1] > -1).sum()),\n",
    "    ('# tokens with gold labels:', (amodel.gold > 0).sum()),\n",
    "):\n",
    "    print(description, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e1b3fab4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the alpha prior (work on a copy so that the model's prior is left untouched)\n",
    "alpha0 = amodel.model.A.alpha0.copy()\n",
    "# note: default prior belief is that annotators assign correct label 2 out of 3 times\n",
    "#       (the ratio below compares entry [0, 0] to the sum of the other entries in row 0; it is 2.0 in the recorded run)\n",
    "alpha0[0,0]/alpha0[0,1:].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3a568c13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                    =>O  =>I-soc  =>I-pol  =>I-pol  =>I-org  =>I-imp  =>B-soc  =>B-pol  =>B-pol  =>B-org  =>B-imp\n",
      "O                                                   6.0      0.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0      1.0\n",
      "I-social group                                      1.0      1.0      0.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0\n",
      "I-political group                                   1.0      0.0      1.0      0.0      0.0      0.0      1.0      0.0      1.0      1.0      1.0\n",
      "I-political institution                             1.0      0.0      0.0      1.0      0.0      0.0      1.0      1.0      0.0      1.0      1.0\n",
      "I-organization, public institution, or collecti...  1.0      0.0      0.0      0.0      1.0      0.0      1.0      1.0      1.0      0.0      1.0\n",
      "I-implicit social group reference                   1.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0      1.0      0.0\n",
      "B-social group                                      1.0      5.0      0.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0\n",
      "B-political group                                   1.0      0.0      5.0      0.0      0.0      0.0      1.0      0.0      1.0      1.0      1.0\n",
      "B-political institution                             1.0      0.0      0.0      5.0      0.0      0.0      1.0      1.0      0.0      1.0      1.0\n",
      "B-organization, public institution, or collecti...  1.0      0.0      0.0      0.0      5.0      0.0      1.0      1.0      1.0      0.0      1.0\n",
      "B-implicit social group reference                   1.0      0.0      0.0      0.0      0.0      5.0      1.0      1.0      1.0      1.0      0.0\n"
     ]
    }
   ],
   "source": [
    "# inspect and adjust the label transitions prior\n",
    "new_beta0 = amodel.model.LM.beta0.copy()\n",
    "# label names ordered by their integer code\n",
    "cats = sorted(acorp.label_map, key=acorp.label_map.get)\n",
    "# snapshot of the prior before the adjustments below\n",
    "tmp = pd.DataFrame(new_beta0.round(2), index=cats, columns=['=>'+c[:5] for c in cats])\n",
    "# note: rows and columns are both ordered by label code; cell (i, j) holds the prior\n",
    "#       for the label in column j following the label in row i\n",
    "# intended constraints:\n",
    "# - after the \"outside\" label, only itself and B-* labels are allowed\n",
    "# - after an I-* label, only itself, O, or the B-* labels of the other types are allowed\n",
    "# - after a B-* label, only O, the I-* label of the same type, or the B-* labels of the other types are allowed\n",
    "# hence, for each type (I-* code c, B-* code c+5), set the transitions I-* => B-* and B-* => B-* to (almost) zero\n",
    "for c in range(1, 6):\n",
    "    new_beta0[c, c+5] = 1e-12\n",
    "    new_beta0[c+5, c+5] = 1e-12\n",
    "\n",
    "# verify\n",
    "print(pd.DataFrame(new_beta0.round(2), index=cats, columns=['=>'+c[:5] for c in cats]))\n",
    "\n",
    "# reset\n",
    "amodel.reset_label_transitions_prior(new_beta0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f24399ff",
   "metadata": {},
   "source": [
    "### Fit the BSC model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ba4f1ebe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BSC: run() called with annotation matrix with shape = (179641, 2)\n",
      "BC iteration 0 in progress\n",
      "BC iteration 0: computed label probabilities\n",
      "BC iteration 0: updated label model\n",
      "BC iteration 0: updated worker models\n",
      "BC iteration 1 in progress\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hlicht/miniforge3/envs/group_mention_detection/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py:752: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BAC iteration 1: completed forward pass\n",
      "BAC iteration 1: completed backward pass\n",
      "BC iteration 1: computed label probabilities\n",
      "BC iteration 1: updated label model\n",
      "BC iteration 1: updated worker models\n",
      "Computing LB=-2565719.4905: label model and labels=-392970.9609, annotator model=-1004.0000, features=-2171744.4983\n",
      "BSC: max. difference at iteration 2: inf\n",
      "BC iteration 2 in progress\n",
      "BAC iteration 2: completed forward pass\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# fit model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mamodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_predict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Dropbox/papers/group_mention_detection/replication/code/utils/bsc_model.py:111\u001b[0m, in \u001b[0;36mBSCModel.fit_predict\u001b[0;34m(self, refit, **kwargs)\u001b[0m\n\u001b[1;32m    108\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mruntime_ \u001b[38;5;241m=\u001b[39m timeit\u001b[38;5;241m.\u001b[39mdefault_timer()\n\u001b[1;32m    109\u001b[0m \u001b[38;5;66;03m# fit and predict\u001b[39;00m\n\u001b[1;32m    110\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mEt_, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabels_, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabel_probs_ \u001b[38;5;241m=\u001b[39m \\\n\u001b[0;32m--> 111\u001b[0m \t\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_predict\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdoc_start\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mruntime_ \u001b[38;5;241m=\u001b[39m timeit\u001b[38;5;241m.\u001b[39mdefault_timer() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mruntime_\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfitted \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "File \u001b[0;32m~/Dropbox/papers/group_mention_detection/replication/code/utils/bayesian_combination/bayesian_combination.py:255\u001b[0m, in \u001b[0;36mBC.fit_predict\u001b[0;34m(self, C, doc_start, features, dev_sentences, gold_labels)\u001b[0m\n\u001b[1;32m    253\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mEt_old \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mEt\n\u001b[1;32m    254\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39miter \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 255\u001b[0m \t\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mEt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate_t\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparallel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mC_data\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    256\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    257\u001b[0m \t\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mEt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLM\u001b[38;5;241m.\u001b[39minit_t(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mC, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdoc_start, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mA, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mblanks, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgold)\n",
      "File \u001b[0;32m~/Dropbox/papers/group_mention_detection/replication/code/utils/bayesian_combination/label_models/markov_label_model.py:125\u001b[0m, in \u001b[0;36mMarkovLabelModel.update_t\u001b[0;34m(self, parallel, C_data)\u001b[0m\n\u001b[1;32m    122\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose:\n\u001b[1;32m    123\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBAC iteration \u001b[39m\u001b[38;5;132;01m%i\u001b[39;00m\u001b[38;5;124m: completed forward pass\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39miter)\n\u001b[0;32m--> 125\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parallel_backward_pass\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    126\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose:\n\u001b[1;32m    127\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBAC iteration \u001b[39m\u001b[38;5;132;01m%i\u001b[39;00m\u001b[38;5;124m: completed backward pass\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39miter)\n",
      "File \u001b[0;32m~/Dropbox/papers/group_mention_detection/replication/code/utils/bayesian_combination/label_models/markov_label_model.py:198\u001b[0m, in \u001b[0;36mMarkovLabelModel._parallel_backward_pass\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    193\u001b[0m     res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparallel(delayed(_doc_backward_pass)(d, doc, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mC_data_by_doc[d], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mblanks_by_doc[d],\n\u001b[1;32m    194\u001b[0m                                                     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlnB_by_doc[d], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mL,\n\u001b[1;32m    195\u001b[0m                                                     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mA, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaling[d]\n\u001b[1;32m    196\u001b[0m                                             ) \u001b[38;5;28;01mfor\u001b[39;00m d, doc \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mC_by_doc))\n\u001b[1;32m    197\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 198\u001b[0m     res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_doc_backward_pass\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdoc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mblanks_by_doc\u001b[49m\u001b[43m[\u001b[49m\u001b[43md\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    199\u001b[0m \u001b[43m                                                    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlnB_by_doc\u001b[49m\u001b[43m[\u001b[49m\u001b[43md\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mL\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m[\u001b[49m\u001b[43md\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    200\u001b[0m \u001b[43m                                            \u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdoc\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mC_by_doc\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    202\u001b[0m \u001b[38;5;66;03m# reformat results\u001b[39;00m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlnLambda \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate(res, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n",
      "File \u001b[0;32m~/miniforge3/envs/group_mention_detection/lib/python3.10/site-packages/joblib/parallel.py:2007\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   2001\u001b[0m \u001b[38;5;66;03m# The first item from the output is blank, but it makes the interpreter\u001b[39;00m\n\u001b[1;32m   2002\u001b[0m \u001b[38;5;66;03m# progress until it enters the Try/Except block of the generator and\u001b[39;00m\n\u001b[1;32m   2003\u001b[0m \u001b[38;5;66;03m# reaches the first `yield` statement. This starts the asynchronous\u001b[39;00m\n\u001b[1;32m   2004\u001b[0m \u001b[38;5;66;03m# dispatch of the tasks to the workers.\u001b[39;00m\n\u001b[1;32m   2005\u001b[0m \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 2007\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniforge3/envs/group_mention_detection/lib/python3.10/site-packages/joblib/parallel.py:1650\u001b[0m, in \u001b[0;36mParallel._get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m   1647\u001b[0m     \u001b[38;5;28;01myield\u001b[39;00m\n\u001b[1;32m   1649\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backend\u001b[38;5;241m.\u001b[39mretrieval_context():\n\u001b[0;32m-> 1650\u001b[0m         \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_retrieve()\n\u001b[1;32m   1652\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mGeneratorExit\u001b[39;00m:\n\u001b[1;32m   1653\u001b[0m     \u001b[38;5;66;03m# The generator has been garbage collected before being fully\u001b[39;00m\n\u001b[1;32m   1654\u001b[0m     \u001b[38;5;66;03m# consumed. This aborts the remaining tasks if possible and warn\u001b[39;00m\n\u001b[1;32m   1655\u001b[0m     \u001b[38;5;66;03m# the user if necessary.\u001b[39;00m\n\u001b[1;32m   1656\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "File \u001b[0;32m~/miniforge3/envs/group_mention_detection/lib/python3.10/site-packages/joblib/parallel.py:1762\u001b[0m, in \u001b[0;36mParallel._retrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1757\u001b[0m \u001b[38;5;66;03m# If the next job is not ready for retrieval yet, we just wait for\u001b[39;00m\n\u001b[1;32m   1758\u001b[0m \u001b[38;5;66;03m# async callbacks to progress.\u001b[39;00m\n\u001b[1;32m   1759\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ((\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jobs) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m   1760\u001b[0m     (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jobs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mget_status(\n\u001b[1;32m   1761\u001b[0m         timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtimeout) \u001b[38;5;241m==\u001b[39m TASK_PENDING)):\n\u001b[0;32m-> 1762\u001b[0m     \u001b[43mtime\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msleep\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0.01\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1763\u001b[0m     \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m   1765\u001b[0m \u001b[38;5;66;03m# We need to be careful: the job list can be filling up as\u001b[39;00m\n\u001b[1;32m   1766\u001b[0m \u001b[38;5;66;03m# we empty it and Python list are not thread-safe by\u001b[39;00m\n\u001b[1;32m   1767\u001b[0m \u001b[38;5;66;03m# default hence the use of the lock\u001b[39;00m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# fit model\n",
    "# NOTE(review): presumably long-running (the saved traceback shows joblib workers being awaited) - confirm before re-running\n",
    "amodel.fit_predict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "821b1bc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "59246.04437354207\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[inf,\n",
       " 100322.22447962966,\n",
       " 16926.870542862453,\n",
       " 2475.791492738761,\n",
       " 587.673656987492,\n",
       " 241.77287979098037,\n",
       " 202.29510506708175,\n",
       " 151.08837817702442,\n",
       " 134.41390146687627,\n",
       " 147.75591368647292,\n",
       " 141.03665407607332,\n",
       " 164.30731191439554,\n",
       " 143.51948086591437,\n",
       " 135.67917372472584,\n",
       " 112.49956936296076,\n",
       " 78.37210077652708,\n",
       " 49.19128704071045,\n",
       " 30.435816916637123,\n",
       " 23.220850703306496,\n",
       " 17.607850742526352,\n",
       " 13.628571960143745,\n",
       " 11.153746377211064,\n",
       " 9.183980046771467,\n",
       " 8.113782392349094,\n",
       " 6.74266731319949,\n",
       " 6.388222127687186,\n",
       " 6.535607845056802,\n",
       " 6.46002184599638]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Print the runtime recorded on the model wrapper, then display the convergence history of the underlying model\n",
    "# NOTE(review): runtime_ is presumably measured in seconds - confirm against BSCModel\n",
    "print(amodel.runtime_)\n",
    "amodel.model.convergence_history"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31caa9b5",
   "metadata": {},
   "source": [
    "### Add sentence metadata to posterior label estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e788e816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive sentence-level metadata columns from the sentence identifier\n",
    "\n",
    "# Pattern layout: <party>[_-]<yyyy>[[-]mm]-<paragraph>-<sentence>\n",
    "# NOTE: the 'month' group captures the optional hyphen together with the two digits\n",
    "pattern = r'^([a-z0-9]+)[_-]((\\d{4})(-?\\d{2})?)-(\\d+-\\d+)$'\n",
    "id_parts = metadata['sentence_id'].str.extract(pattern)\n",
    "metadata[['party', 'tmp1', 'year', 'month', 'tmp2']] = id_parts\n",
    "\n",
    "# The trailing '<paragraph>-<sentence>' group is split on its first hyphen only\n",
    "position_parts = metadata['tmp2'].str.split('-', n=1, expand=True)\n",
    "metadata[['paragraph_nr', 'sentence_nr']] = position_parts\n",
    "\n",
    "# Remove the helper columns, then index the frame by 'uid' (duplicates raise)\n",
    "metadata.drop(columns=['tmp1', 'tmp2'], inplace=True)\n",
    "metadata.set_index('uid', drop=True, verify_integrity=True, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "ef99669e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach each document's metadata row as a plain dict, leaving out missing (NA) fields\n",
    "for doc in amodel.corpus.docs:\n",
    "    doc_metadata = {}\n",
    "    for field, value in dict(metadata.loc[doc.id]).items():\n",
    "        if pd.isna(value):\n",
    "            continue\n",
    "        doc_metadata[field] = value\n",
    "    doc.metadata = doc_metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3234da91",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ True]), array([8576]))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count how many documents do / do not carry a 'metadata' attribute\n",
    "has_metadata = np.array([hasattr(doc, 'metadata') for doc in amodel.corpus.docs])\n",
    "np.unique(has_metadata, return_counts=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d844b81",
   "metadata": {},
   "source": [
    "## Write to disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "dbcd0e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _jsonify(self, fields=['id', 'text', 'tokens', 'annotations', 'labels', 'metadata']):\n",
    "    out = {k: self.__dict__[k] for k in fields if hasattr(self, k)}\n",
    "    if 'annotations' in out.keys():\n",
    "        out['annotations'] = {k: v.tolist() for k, v in out['annotations'].items()}\n",
    "    if 'labels' in out.keys():\n",
    "        out['labels'] = {k: v.tolist() for k, v in out['labels'].items()}\n",
    "        if len(out['labels']) == 0:\n",
    "            out['labels'] = None\n",
    "    if 'metadata' in out.keys():\n",
    "        if len(out['metadata']) == 0:\n",
    "            out['metadata'] = None\n",
    "    return json.dumps(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f270ad74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one JSON line per document; an existing output file is only replaced on request\n",
    "if args.overwrite_output or not os.path.exists(args.output_file):\n",
    "    # Serialize everything first: a failure then cannot leave behind a truncated file\n",
    "    # that the existence check above would refuse to overwrite on the next run\n",
    "    lines = [_jsonify(doc) + '\\n' for doc in amodel.corpus.docs]\n",
    "    # make sure the target folder exists before opening the file\n",
    "    os.makedirs(os.path.dirname(os.path.abspath(args.output_file)), exist_ok=True)\n",
    "    with open(args.output_file, mode='w', encoding='utf-8') as writer:\n",
    "        writer.writelines(lines)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "group_mention_detection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
